##
## This file is part of the libsigrokdecode project.
##
## Copyright (C) 2014 Daniel Elstner <daniel.kitta@gmail.com>
##
## This program is free software; you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program; if not, see <http://www.gnu.org/licenses/>.
##

import sigrokdecode as srd
from functools import reduce
from .tables import instr_table_by_prefix
import string

class Ann:
    """Annotation class indices, in the order of Decoder.annotations."""
    ADDR = 0    # 'addr'
    MEMRD = 1   # 'memrd'
    MEMWR = 2   # 'memwr'
    IORD = 3    # 'iord'
    IOWR = 4    # 'iowr'
    INSTR = 5   # 'instr'
    ROP = 6     # 'rop'
    WOP = 7     # 'wop'
    WARN = 8    # 'warn'
class Row:
    # Annotation row indices.
    # NOTE(review): presumably these mirror the order of
    # Decoder.annotation_rows; nothing in this file reads them -- confirm.
    ADDRBUS = 0
    DATABUS = 1
    INSTRUCTIONS = 2
    OPERANDS = 3
    WARNINGS = 4
class Pin:
    """Positions of the channels within the pin tuple read in decode().

    The order is that of Decoder.channels followed by
    Decoder.optional_channels.
    """
    # Data bus: first and last line.
    D0 = 0
    D7 = 7
    # Control lines.
    M1 = 8
    RD = 9
    WR = 10
    MREQ = 11
    IORQ = 12
    # Address bus: first and last line.
    A0 = 13
    A15 = 28
class Cycle:
    """Kinds of bus cycle distinguished by Decoder.decode()."""
    NONE = 0    # no recognized cycle in progress
    MEMRD = 1   # memory read
    MEMWR = 2   # memory write
    IORD = 3    # I/O read
    IOWR = 4    # I/O write
    FETCH = 5   # memory read with /M1 low
    INTACK = 6  # I/O request with /M1 low

# Provide custom format type 'H' for hexadecimal output
# with leading decimal digit (assembler syntax).
class AsmFormatter(string.Formatter):
    """string.Formatter that understands an extra 'H' presentation type.

    A format spec ending in 'H' is rendered like 'X' (upper-case hex), and
    a '0' is inserted in front of the number when its first digit would
    otherwise be a letter, e.g. 0xFF -> '0FF'. All other format specs are
    passed through to the built-in format() unchanged.
    """

    def format_field(self, value, format_spec):
        """Format a single replacement field.

        value       -- the object to format (an integer for the 'H' type).
        format_spec -- standard format spec, optionally ending in 'H'.

        Returns the formatted string.
        """
        if not format_spec.endswith('H'):
            return format(value, format_spec)

        result = format(value, format_spec[:-1] + 'X')
        # Look at the first character of the number itself: a sign or space
        # padding in front of it must not be mistaken for a non-digit
        # (otherwise -10 would come out as '0-A' instead of '-0A').
        lead = len(result) - len(result.lstrip('+- '))
        if lead == len(result) or result[lead] in string.digits:
            return result
        return result[:lead] + '0' + result[lead:]

formatter = AsmFormatter()

# Data bus annotation class used for the byte transferred in each kind of
# bus cycle. Opcode fetches are annotated as memory reads and interrupt
# acknowledge cycles as I/O reads. Cycle.NONE deliberately has no entry.
ann_data_cycle_map = dict((
    (Cycle.MEMRD,  Ann.MEMRD),
    (Cycle.MEMWR,  Ann.MEMWR),
    (Cycle.IORD,   Ann.IORD),
    (Cycle.IOWR,   Ann.IOWR),
    (Cycle.FETCH,  Ann.MEMRD),
    (Cycle.INTACK, Ann.IORD),
))

def reduce_bus(bus):
    """Combine the pin values of a bus into one integer.

    bus -- sequence of pin values, least significant line first.

    Returns None if any line carries the value 0xFF (unassigned bus
    channels), otherwise the integer formed with bus[0] as bit 0.
    """
    def shift_in(acc, bit):
        return (acc << 1) | bit

    if any(level == 0xFF for level in bus):
        return None # unassigned bus channels
    # Fold starting from the most significant line.
    return reduce(shift_in, reversed(bus))

def signed_byte(byte):
    """Reinterpret a byte value as signed: 128 and above map to byte - 256."""
    if byte >= 128:
        return byte - 256
    return byte

class Decoder(srd.Decoder):
    """Z80 bus decoder and disassembler.

    decode() classifies every sample into a bus cycle type from the /M1,
    /RD, /WR, /MREQ and /IORQ lines and calls the on_cycle_* handlers when
    the cycle type changes. The handlers emit address and data bus
    annotations and drive a state machine (the state_* methods) that looks
    the transferred bytes up in the instruction tables to emit instruction
    and operand annotations.

    Each state_* method returns the bound method to use as the next state.
    They are invoked from on_cycle_end() *before* pend_data is refreshed,
    so inside a state method:
      - self.prev_cycle is the type of the bus cycle that just ended,
      - self.bus_data is the byte transferred in that cycle,
      - self.pend_data is the byte transferred in the cycle before it.
    """
    api_version = 3
    id       = 'z80'
    name     = 'Z80'
    longname = 'Zilog Z80 CPU'
    desc     = 'Zilog Z80 microprocessor disassembly.'
    license  = 'gplv3+'
    inputs   = ['logic']
    outputs  = []
    tags     = ['Retro computing']
    channels = tuple({
            'id': 'd%d' % i,
            'name': 'D%d' % i,
            'desc': 'Data bus line %d' % i
            } for i in range(8)
    ) + (
        {'id': 'm1', 'name': '/M1', 'desc': 'Machine cycle 1'},
        {'id': 'rd', 'name': '/RD', 'desc': 'Memory or I/O read'},
        {'id': 'wr', 'name': '/WR', 'desc': 'Memory or I/O write'},
    )
    optional_channels = (
        {'id': 'mreq', 'name': '/MREQ', 'desc': 'Memory request'},
        {'id': 'iorq', 'name': '/IORQ', 'desc': 'I/O request'},
    ) + tuple({
        'id': 'a%d' % i,
        'name': 'A%d' % i,
        'desc': 'Address bus line %d' % i
        } for i in range(16)
    )
    annotations = (
        ('addr',  'Memory or I/O address'),
        ('memrd', 'Byte read from memory'),
        ('memwr', 'Byte written to memory'),
        ('iord',  'Byte read from I/O port'),
        ('iowr',  'Byte written to I/O port'),
        ('instr', 'Z80 CPU instruction'),
        ('rop',   'Value of input operand'),
        ('wop',   'Value of output operand'),
        ('warn',  'Warning message'),
    )
    annotation_rows = (
        ('addrbus', 'Address bus', (Ann.ADDR,)),
        ('databus', 'Data bus', (Ann.MEMRD, Ann.MEMWR, Ann.IORD, Ann.IOWR)),
        ('instructions', 'Instructions', (Ann.INSTR,)),
        ('operands', 'Operands', (Ann.ROP, Ann.WOP)),
        ('warnings', 'Warnings', (Ann.WARN,))
    )

    def __init__(self):
        self.reset()

    def reset(self):
        """Return all decoding state to its initial values.

        Everything that carries over from one bus cycle to the next is
        cleared here, so that a reset which is not followed by a new
        start() cannot leave stale pending annotations behind.
        """
        self.bus_data   = None  # last byte seen on the data bus
        self.samplenum  = None
        self.addr_start = None  # start sample of the pending address
        self.data_start = None  # start sample of the pending data byte
        self.dasm_start = None  # start sample of the pending disassembly
        self.pend_addr  = None  # address not yet annotated
        self.pend_data  = None  # data byte not yet annotated
        self.ann_data   = None  # annotation class for pend_data
        self.ann_dasm   = None  # annotation class for pending disassembly
        self.prev_cycle = Cycle.NONE
        self.op_state   = self.state_IDLE
        self.instr_len  = 0

    def start(self):
        """Register the annotation output and begin from a clean state."""
        self.out_ann = self.register(srd.OUTPUT_ANN)
        # start() has always left the decoder in its initial state; keep
        # that guarantee now that the initialisation lives in reset().
        self.reset()

    def decode(self):
        """Classify each sample into a bus cycle and dispatch on changes.

        Runs until the backend stops delivering samples from wait().
        """
        while True:
            # TODO: Come up with more appropriate self.wait() conditions.
            pins = self.wait()
            cycle = Cycle.NONE
            # Anything other than 1 counts as asserted here, which includes
            # the 0xFF value that reduce_bus() treats as an unassigned
            # channel.
            if pins[Pin.MREQ] != 1: # default to asserted
                if pins[Pin.RD] == 0:
                    cycle = Cycle.FETCH if pins[Pin.M1] == 0 else Cycle.MEMRD
                elif pins[Pin.WR] == 0:
                    cycle = Cycle.MEMWR
            elif pins[Pin.IORQ] == 0: # default to not asserted
                if pins[Pin.M1] == 0:
                    cycle = Cycle.INTACK
                elif pins[Pin.RD] == 0:
                    cycle = Cycle.IORD
                elif pins[Pin.WR] == 0:
                    cycle = Cycle.IOWR

            # The data bus is only sampled while a cycle is in progress, so
            # bus_data keeps the last value seen during the cycle once the
            # control lines have been released.
            if cycle != Cycle.NONE:
                self.bus_data = reduce_bus(pins[Pin.D0:Pin.D7+1])
            if cycle != self.prev_cycle:
                if self.prev_cycle == Cycle.NONE:
                    self.on_cycle_begin(reduce_bus(pins[Pin.A0:Pin.A15+1]))
                elif cycle == Cycle.NONE:
                    self.on_cycle_end()
                else:
                    self.on_cycle_trans()
            # Updated last on purpose: the handlers above still see the type
            # of the cycle that was in progress before this sample.
            self.prev_cycle = cycle

    def on_cycle_begin(self, bus_addr):
        """Handle the start of a bus cycle.

        bus_addr -- address bus value at the start of the cycle, or None
                    if the address channels are not all assigned.

        The address of the previous cycle, if any, is annotated up to the
        current sample; the new address becomes the pending one.
        """
        if self.pend_addr is not None:
            self.put_text(self.addr_start, Ann.ADDR,
                          '{:04X}'.format(self.pend_addr))
        self.addr_start = self.samplenum
        self.pend_addr  = bus_addr

    def on_cycle_end(self):
        """Handle the end of a bus cycle.

        Advances the instruction state machine, emits any disassembly it
        produced, and annotates the data byte of the previous cycle. The
        byte of the cycle that just ended becomes the pending one.
        """
        self.instr_len += 1
        self.op_state = self.op_state()
        if self.ann_dasm is not None:
            self.put_disasm()
        if self.op_state == self.state_RESTART:
            # Run state_IDLE right away so that the cycle which just ended
            # is considered as the possible start of the next instruction.
            self.op_state = self.state_IDLE()

        if self.ann_data is not None:
            self.put_text(self.data_start, self.ann_data,
                          '{:02X}'.format(self.pend_data))
        self.data_start = self.samplenum
        self.pend_data  = self.bus_data
        self.ann_data   = ann_data_cycle_map[self.prev_cycle]

    def on_cycle_trans(self):
        """Handle a direct change from one cycle type to another.

        Emits a warning and discards the pending address, data and
        disassembly annotations.
        """
        self.put_text(self.samplenum - 1, Ann.WARN,
                      'Illegal transition between control states')
        self.pend_addr = None
        self.ann_data  = None
        self.ann_dasm  = None

    def put_disasm(self):
        """Emit the pending disassembly annotation.

        self.mnemonic is used as a format template with the fields r
        (register name), d (displacement), j (displacement plus
        instruction length), i (immediate), ro (read operand) and wo
        (write operand). The annotation spans from dasm_start to the
        current sample; afterwards the next one starts here.
        """
        text = formatter.format(self.mnemonic, r=self.arg_reg, d=self.arg_dis,
                                j=self.arg_dis+self.instr_len, i=self.arg_imm,
                                ro=self.arg_read, wo=self.arg_write)
        self.put_text(self.dasm_start, self.ann_dasm, text)
        self.ann_dasm   = None
        self.dasm_start = self.samplenum

    def put_text(self, ss, ann_idx, ann_text):
        """Emit annotation text from sample ss up to the current sample."""
        self.put(ss, self.samplenum, self.out_ann, [ann_idx, [ann_text]])

    def _instr_complete(self):
        """Mark the instruction as fully decoded and pick the next state.

        Shared tail of the opcode, displacement and immediate states: the
        mnemonic is scheduled for output, and if the cycle that just ended
        already transfers an expected operand the matching operand state
        follows, otherwise decoding restarts.
        """
        self.ann_dasm = Ann.INSTR
        if self.want_read > 0 and self.prev_cycle in (Cycle.MEMRD, Cycle.IORD):
            return self.state_ROP1
        if self.want_write > 0 and self.prev_cycle in (Cycle.MEMWR, Cycle.IOWR):
            return self.state_WOP1
        return self.state_RESTART

    def state_RESTART(self):
        """Marker state; on_cycle_end() replaces it by running state_IDLE."""
        return self.state_IDLE

    def state_IDLE(self):
        """Wait for an opcode fetch and set up decoding of an instruction.

        Stays idle unless the cycle that just ended was a fetch. Otherwise
        clears all per-instruction fields and continues with state_PRE1 if
        the fetched byte is a prefix (CB, ED, DD, FD), else state_OPCODE.
        """
        if self.prev_cycle != Cycle.FETCH:
            return self.state_IDLE
        self.want_dis   = 0
        self.want_imm   = 0
        self.want_read  = 0
        self.want_write = 0
        self.want_wr_be = False
        self.op_repeat  = False
        self.arg_dis    = 0
        self.arg_imm    = 0
        self.arg_read   = 0
        self.arg_write  = 0
        self.arg_reg    = ''
        self.mnemonic   = ''
        self.instr_pend = False
        self.read_pend  = False
        self.write_pend = False
        self.dasm_start = self.samplenum
        self.op_prefix  = 0
        self.instr_len  = 0
        if self.bus_data in (0xCB, 0xED, 0xDD, 0xFD):
            return self.state_PRE1
        else:
            return self.state_OPCODE

    def state_PRE1(self):
        """Handle the byte fetched after a prefix byte.

        pend_data is the prefix, bus_data the byte just fetched. After DD
        or FD, a CB leads to the two-byte prefix path (state_PRE2) and
        another DD/ED/FD is treated as a new prefix; everything else is an
        opcode. Warns if the cycle was not a fetch.
        """
        if self.prev_cycle != Cycle.FETCH:
            self.mnemonic = 'Prefix not followed by fetch'
            self.ann_dasm = Ann.WARN
            return self.state_RESTART
        self.op_prefix = self.pend_data
        if self.op_prefix in (0xDD, 0xFD):
            if self.bus_data == 0xCB:
                return self.state_PRE2
            if self.bus_data in (0xDD, 0xED, 0xFD):
                return self.state_PRE1
        return self.state_OPCODE

    def state_PRE2(self):
        """Combine a DD/FD prefix with the following CB.

        The cycle that just ended must be the memory read of the
        displacement; pend_data still holds the second prefix byte.
        """
        if self.prev_cycle != Cycle.MEMRD:
            self.mnemonic = 'Missing displacement'
            self.ann_dasm = Ann.WARN
            return self.state_RESTART
        self.op_prefix = (self.op_prefix << 8) | self.pend_data
        return self.state_PREDIS

    def state_PREDIS(self):
        """Record the displacement that precedes the opcode.

        The cycle that just ended must be the memory read of the opcode;
        pend_data holds the displacement byte.
        """
        if self.prev_cycle != Cycle.MEMRD:
            self.mnemonic = 'Missing opcode'
            self.ann_dasm = Ann.WARN
            return self.state_RESTART
        self.arg_dis = signed_byte(self.pend_data)
        return self.state_OPCODE

    def state_OPCODE(self):
        """Look up the opcode in pend_data and plan the operand states.

        The table is selected by the collected prefix. An unknown opcode
        produces an 'Invalid instruction' warning.
        """
        (table, self.arg_reg) = instr_table_by_prefix[self.op_prefix]
        self.op_prefix = 0
        instruction = table.get(self.pend_data, None)
        if instruction is None:
            self.mnemonic = 'Invalid instruction'
            self.ann_dasm = Ann.WARN
            return self.state_RESTART
        # NOTE(review): the meaning of the table fields is inferred from
        # how they are used below -- confirm against the tables module.
        (self.want_dis, self.want_imm, self.want_read, want_write,
                self.op_repeat, self.mnemonic) = instruction
        # The sign of want_write only selects the byte order used in
        # state_WOP2; its magnitude is the number of bytes written.
        self.want_write = abs(want_write)
        self.want_wr_be = (want_write < 0)
        if self.want_dis > 0:
            return self.state_POSTDIS
        if self.want_imm > 0:
            return self.state_IMM1
        return self._instr_complete()

    def state_POSTDIS(self):
        """Record the displacement byte that follows the opcode."""
        self.arg_dis = signed_byte(self.pend_data)
        if self.want_imm > 0:
            return self.state_IMM1
        return self._instr_complete()

    def state_IMM1(self):
        """Record the first (low) byte of the immediate operand."""
        self.arg_imm = self.pend_data
        if self.want_imm > 1:
            return self.state_IMM2
        return self._instr_complete()

    def state_IMM2(self):
        """Record the second (high) byte of the immediate operand."""
        self.arg_imm |= self.pend_data << 8
        return self._instr_complete()

    def state_ROP1(self):
        """Record the first byte read as an operand.

        A single-byte read operand is scheduled for output immediately; a
        two-byte one is completed in state_ROP2.
        """
        self.arg_read = self.pend_data
        if self.want_read < 2:
            self.mnemonic = '{ro:02X}'
            self.ann_dasm = Ann.ROP
        if self.want_write > 0:
            return self.state_WOP1
        if self.want_read > 1:
            return self.state_ROP2
        if self.op_repeat and self.prev_cycle in (Cycle.MEMRD, Cycle.IORD):
            return self.state_ROP1
        return self.state_RESTART

    def state_ROP2(self):
        """Record the high byte of a two-byte read operand and output it."""
        self.arg_read |= self.pend_data << 8
        self.mnemonic = '{ro:04X}'
        self.ann_dasm = Ann.ROP
        if self.want_write > 0 and self.prev_cycle in (Cycle.MEMWR, Cycle.IOWR):
            return self.state_WOP1
        return self.state_RESTART

    def state_WOP1(self):
        """Record the first byte written as an operand.

        Defers to state_ROP2 or state_WOP2 when a second read or write
        byte is still expected; otherwise the single byte is scheduled
        for output.
        """
        self.arg_write = self.pend_data
        if self.want_read > 1:
            return self.state_ROP2
        if self.want_write > 1:
            return self.state_WOP2
        self.mnemonic = '{wo:02X}'
        self.ann_dasm = Ann.WOP
        if self.want_read > 0 and self.op_repeat and \
                self.prev_cycle in (Cycle.MEMRD, Cycle.IORD):
            return self.state_ROP1
        return self.state_RESTART

    def state_WOP2(self):
        """Record the second byte of a two-byte write operand and output it.

        With want_wr_be set the first byte written is the high byte,
        otherwise it is the low byte.
        """
        if self.want_wr_be:
            self.arg_write = (self.arg_write << 8) | self.pend_data
        else:
            self.arg_write |= self.pend_data << 8
        self.mnemonic = '{wo:04X}'
        self.ann_dasm = Ann.WOP
        return self.state_RESTART